package org.wikipedia.miner.comparison;


import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Vector;

import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.CorrelationCalculator;
import org.wikipedia.miner.util.ProgressTracker;

import weka.classifiers.Classifier;
import weka.classifiers.functions.GaussianProcesses;
import weka.core.Instance;

import org.dmilne.weka.wrapper.Dataset;
import org.dmilne.weka.wrapper.Decider;
import org.dmilne.weka.wrapper.DeciderBuilder;
import org.dmilne.weka.wrapper.InstanceBuilder;

/**
 * Assigns a numeric weight to {@link ConnectionSnippet}s.
 * <p>
 * When a trained snippet model is ready, the weight is that model's prediction for the snippet's
 * features. The model is either loaded from the file returned by
 * {@code wikipedia.getConfig().getComparisonSnippetModel()}, or built via
 * {@link #buildClassifier(Classifier)} / {@link #buildDefaultClassifier()}. Otherwise the weight
 * falls back to the mean relatedness of the snippet's source article to its two topics.
 * <p>
 * Training workflow:
 * <ol>
 * <li>Gather training instances with {@link #train(Vector)} or {@link #loadTrainingData(File)}.</li>
 * <li>Fit a model with {@link #buildClassifier(Classifier)} or {@link #buildDefaultClassifier()}.</li>
 * <li>Optionally persist the model with {@link #saveClassifier(File)}.</li>
 * </ol>
 */
public class ConnectionSnippetWeighter {

	/**
	 * Features extracted from each snippet (see {@code getInstance(ConnectionSnippet)}).
	 * The boolean-typed ones are declared in the constructor; all others are numeric.
	 */
	enum Attributes {
		generality,
		inLinkCount,
		outLinkCount,
		isTopic1,
		relatednessToTopic1,
		isTopic2,
		relatednessToTopic2,
		sentenceIndex,
		wordCount,
		isListItem,
		isFromFirstParagraph,
		isAfterHeading
	}

	// Only used at construction time to locate an optional pre-trained model.
	private final Wikipedia wikipedia ;

	// Supplies article-to-article relatedness, both as a feature and as the fallback weight.
	private final ArticleComparer cmp ;

	// Regression model mapping snippet features to a weight.
	private final Decider<Attributes,Double> snippetWeighter ;

	// Training instances gathered by train() or loadTrainingData(); null until one of them runs.
	private Dataset<Attributes,Double> trainingDataset ;

	/**
	 * Creates a weighter and, if the Wikipedia configuration names a comparison snippet model,
	 * loads it.
	 *
	 * @param wikipedia the Wikipedia instance whose configuration may name a snippet model
	 * @param cmp the comparer used to measure relatedness between articles
	 * @throws Exception if the decider cannot be built or the configured model cannot be loaded
	 */
	@SuppressWarnings("unchecked")
	public ConnectionSnippetWeighter(Wikipedia wikipedia, ArticleComparer cmp) throws Exception {

		this.wikipedia = wikipedia ;
		this.cmp = cmp ;

		// Every attribute is numeric unless explicitly marked boolean below;
		// the predicted class ("snippetWeight") is numeric.
		snippetWeighter = (Decider<Attributes, Double>) new DeciderBuilder<Attributes>("connectionSnippetWeighter", Attributes.class) 
		.setDefaultAttributeTypeNumeric()
		.setAttributeTypeBoolean(Attributes.isTopic1)
		.setAttributeTypeBoolean(Attributes.isTopic2)
		.setAttributeTypeBoolean(Attributes.isAfterHeading)
		.setAttributeTypeBoolean(Attributes.isListItem)
		.setAttributeTypeBoolean(Attributes.isFromFirstParagraph)
		.setClassAttributeTypeNumeric("snippetWeight")
		.build();

		if (wikipedia.getConfig().getComparisonSnippetModel() != null) 
			this.loadClassifier(wikipedia.getConfig().getComparisonSnippetModel()) ;
	}

	/**
	 * Returns a weight for the given snippet.
	 * <p>
	 * If the snippet model is ready, this is the model's prediction. Otherwise it is the mean of
	 * {@code cmp.getRelatedness(source, topic1)} and {@code cmp.getRelatedness(source, topic2)}.
	 *
	 * @param snippet the snippet to weight
	 * @return the snippet's weight
	 * @throws Exception if relatedness cannot be computed or the model fails to classify
	 */
	public double getWeight(ConnectionSnippet snippet) throws Exception {

		if (snippetWeighter.isReady())
			return snippetWeighter.getDecision(getInstance(snippet)) ;

		// No model available: fall back to the mean relatedness of the source to both topics.
		double relatednessToTopic1 = cmp.getRelatedness(snippet.getSource(), snippet.getTopic1()) ;
		double relatednessToTopic2 = cmp.getRelatedness(snippet.getSource(), snippet.getTopic2()) ;

		return (relatednessToTopic1 + relatednessToTopic2) / 2 ;
	}

	/**
	 * Builds a fresh training dataset from manually weighted snippets, replacing any previous one.
	 * <p>
	 * This only gathers instances. Call {@link #buildClassifier(Classifier)} or
	 * {@link #buildDefaultClassifier()} afterwards to fit a model.
	 *
	 * @param weightedSnippets snippets whose {@code getWeight()} is non-null
	 * @throws Exception if any snippet is unweighted or its features cannot be computed; in that
	 *         case the previously held training data is left unchanged
	 */
	public void train(Vector<ConnectionSnippet> weightedSnippets) throws Exception {

		// Build into a local so a failure part-way through does not leave a partial dataset behind.
		Dataset<Attributes,Double> dataset = snippetWeighter.createNewDataset() ;

		ProgressTracker pn = new ProgressTracker(weightedSnippets.size(), "training", ConnectionSnippetWeighter.class) ;
		for (ConnectionSnippet snippet: weightedSnippets) {

			if (snippet.getWeight() == null)
				throw new Exception("Training snippet is not weighted") ;

			dataset.add(getInstance(snippet)) ;

			pn.update() ;
		}

		trainingDataset = dataset ;
	}

	/**
	 * Compares this weighter's output against manually assigned weights.
	 *
	 * @param weightedSnippets snippets whose {@code getWeight()} is non-null
	 * @return the correlation, as computed by {@code CorrelationCalculator.getCorrelation}, between
	 *         the manual weights and the weights returned by {@link #getWeight(ConnectionSnippet)}
	 * @throws Exception if any snippet is unweighted or cannot be weighted automatically
	 */
	public double test(Vector<ConnectionSnippet> weightedSnippets) throws Exception {

		List<Double> manualWeights = new ArrayList<Double>(weightedSnippets.size()) ;
		List<Double> autoWeights = new ArrayList<Double>(weightedSnippets.size()) ;

		ProgressTracker pn = new ProgressTracker(weightedSnippets.size(), "testing", ConnectionSnippetWeighter.class) ;
		for (ConnectionSnippet snippet: weightedSnippets) {

			if (snippet.getWeight() == null)
				throw new Exception("Testing snippet is not weighted") ;

			manualWeights.add(snippet.getWeight()) ;
			autoWeights.add(this.getWeight(snippet)) ;

			pn.update() ;
		}

		return CorrelationCalculator.getCorrelation(manualWeights, autoWeights) ;
	}

	/**
	 * Saves the training data generated by train() to the given file.
	 * The data is saved in WEKA's arff format.
	 *
	 * @param file the file to save the training data to
	 * @throws IllegalStateException if no training data has been gathered or loaded yet
	 * @throws IOException if the file cannot be written to
	 */
	public void saveTrainingData(File file) throws IOException, Exception {
		requireTrainingData().save(file) ;
	}

	/**
	 * Loads training data from the given file.
	 * The file must be a valid WEKA arff file.
	 * <p>
	 * If no dataset exists yet, an empty one is created first, so this may be called
	 * without a prior call to {@link #train(Vector)}.
	 *
	 * @param file the file to load the training data from
	 * @throws IOException if the file cannot be read.
	 * @throws Exception if the file does not contain valid training data.
	 */
	public void loadTrainingData(File file) throws IOException, Exception{
		if (trainingDataset == null)
			trainingDataset = snippetWeighter.createNewDataset() ;

		trainingDataset.load(file) ;
	}

	/**
	 * Serializes the classifier and saves it to the given file.
	 *
	 * @param file the file to save the classifier to
	 * @throws IOException if the file cannot be written to
	 */
	public void saveClassifier(File file) throws IOException {
		snippetWeighter.save(file) ;
	}

	/**
	 * Loads a previously saved classifier from the given file.
	 *
	 * @param file the file to load the classifier from
	 * @throws Exception if the file cannot be read or does not contain a valid classifier
	 */
	public void loadClassifier(File file) throws Exception {
		snippetWeighter.load(file) ;
	}

	/**
	 * Trains the given classifier on the current training data and uses it for weighting.
	 *
	 * @param classifier the (untrained) WEKA classifier to fit
	 * @throws IllegalStateException if no training data has been gathered or loaded yet
	 * @throws Exception if training fails
	 */
	public void buildClassifier(Classifier classifier) throws Exception {

		snippetWeighter.train(classifier, requireTrainingData()) ;
	}

	/**
	 * Trains a {@link GaussianProcesses} regressor with default settings on the current
	 * training data and uses it for weighting.
	 *
	 * @throws IllegalStateException if no training data has been gathered or loaded yet
	 * @throws Exception if training fails
	 */
	public void buildDefaultClassifier() throws Exception {

		buildClassifier(new GaussianProcesses()) ;
	}

	/**
	 * Returns the current training dataset, failing with a descriptive message when there is none.
	 */
	private Dataset<Attributes,Double> requireTrainingData() {
		if (trainingDataset == null)
			throw new IllegalStateException("No training data available: call train() or loadTrainingData() first") ;

		return trainingDataset ;
	}

	/**
	 * Converts a snippet into a feature instance for the snippet model. If the snippet carries
	 * a manual weight, that weight is set as the class value.
	 */
	private Instance getInstance(ConnectionSnippet snippet) throws Exception {

		InstanceBuilder<Attributes, Double> ib = snippetWeighter.getInstanceBuilder() ;

		ib.setAttribute(Attributes.generality, snippet.getSource().getGenerality()) ;
		// Link counts are log-scaled; +1 keeps log() defined for zero counts.
		ib.setAttribute(Attributes.inLinkCount, Math.log(snippet.getSource().getDistinctLinksInCount() +1)) ;
		ib.setAttribute(Attributes.outLinkCount, Math.log(snippet.getSource().getDistinctLinksOutCount() +1)) ;
		ib.setAttribute(Attributes.isTopic1, snippet.getSource().getId() == snippet.getTopic1().getId()) ;
		ib.setAttribute(Attributes.relatednessToTopic1, cmp.getRelatedness(snippet.getSource(), snippet.getTopic1())) ;
		ib.setAttribute(Attributes.isTopic2, snippet.getSource().getId() == snippet.getTopic2().getId()) ;
		ib.setAttribute(Attributes.relatednessToTopic2, cmp.getRelatedness(snippet.getSource(), snippet.getTopic2())) ;

		ib.setAttribute(Attributes.sentenceIndex, snippet.getSentenceIndex()) ;

		// Word count approximated as the number of whitespace-separated tokens.
		StringTokenizer t = new StringTokenizer(snippet.getPlainText()) ;
		ib.setAttribute(Attributes.wordCount, t.countTokens()) ;

		ib.setAttribute(Attributes.isListItem, snippet.isListItem()) ;
		ib.setAttribute(Attributes.isAfterHeading, snippet.followsHeading()) ;
		// NOTE(review): Attributes.isFromFirstParagraph is declared (and typed boolean in the
		// constructor) but never set here. Confirm whether ConnectionSnippet exposes this value
		// and populate it, or remove the attribute (which would invalidate saved models).

		if (snippet.getWeight() != null)
			ib.setClassAttribute(snippet.getWeight()) ;

		return ib.build() ;
	}
}
